<!-- Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================-->

<html>

<head>
  <title>TensorFlow.js Model Benchmark</title>
  <link href="https://fonts.googleapis.com/css?family=Roboto" rel="stylesheet">
  <link href="./main.css" rel="stylesheet">
  <script src="https://cdnjs.cloudflare.com/ajax/libs/dat-gui/0.7.2/dat.gui.min.js"></script>
</head>

<body>
  <h2>TensorFlow.js Model Benchmark</h2>
  <div id="modal-msg"></div>
  <div id="container">
    <div id="stats">
      <div class="box">
        <pre id="env"></pre>
      </div>
      <table class="table" id="timings">
        <thead>
          <tr>
            <th>Type</th>
            <th>Value</th>
          </tr>
        </thead>
        <tbody>
        </tbody>
      </table>
      <div class="box" id="perf-trendline-container">
        <div class="label">Inference times</div>
        <div class="trendline">
          <div class="yMax"></div>
          <div class="yMin"></div>
          <svg>
            <path></path>
          </svg>
        </div>
      </div>
    </div>
    <table class="table" id="kernels">
      <thead id="kernels-thead">
      </thead>
      <tbody></tbody>
    </table>
  </div>
  <script src="https://unpkg.com/@tensorflow/tfjs-core@latest/dist/tf-core.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-layers@latest/dist/tf-layers.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-converter@latest/dist/tf-converter.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-backend-wasm@latest/dist/tf-backend-wasm.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-automl@latest/dist/tf-automl.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/universal-sentence-encoder"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/posenet@2"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/body-pix@2"></script>

  <script src="./modelConfig.js"></script>
  <script src="./util.js"></script>
  <script>
    'use strict';

    // Benchmark settings; onPageLoad() binds each of these properties to a
    // dat.gui control.
    const state = {
      numRuns: 50,
      benchmark: 'mobilenet_v2',
      // Exposed as a control by gui.add(state, 'run'); starts one benchmark
      // run.
      run() {
        runBenchmark();
      },
      backend: 'wasm',
      kernelTiming: 'aggregate',
    };

    // Page elements the benchmark steps read from or write into.
    const modalDiv = document.getElementById('modal-msg');
    const envDiv = document.getElementById('env');
    const timeTable = document.querySelector('#timings tbody');
    const kernelsTableHead = document.getElementById('kernels-thead');
    const kernelTable = document.querySelector('#kernels tbody');

    // Assigned by loadAndRecordTime() (model, predict) and
    // measureAveragePredictTime() (chartWidth).
    let model;
    let predict;
    let chartWidth;

    /**
     * Shows `message` (followed by '...') in the modal banner, or hides the
     * banner when `message` is null/undefined.
     * Two animation frames are awaited afterwards, presumably so the banner
     * change is painted before the caller continues with heavy work.
     *
     * @param {?string} message Text (may contain HTML) or null to hide.
     */
    async function showMsg(message) {
      if (message == null) {
        modalDiv.style.display = 'none';
      } else {
        modalDiv.innerHTML = `${message}...`;
        modalDiv.style.display = 'block';
      }
      for (let frame = 0; frame < 2; frame++) {
        await tf.nextFrame();
      }
    }

    /** Writes the loaded tfjs core/layers/converter versions into the env box. */
    function showVersions() {
      const versions = {
        core: tf.version_core,
        layers: tf.version_layers,
        converter: tf.version_converter,
      };
      envDiv.innerHTML = JSON.stringify(versions, null, 2);
    }

    /**
     * Runs one tiny add op under tf.time() on the active backend, then appends
     * the evaluated environment flags (tf.env().features) to the env box.
     * NOTE(review): the op presumably exists to force lazy backend/flag
     * evaluation before the flags are printed; confirm.
     */
    async function showEnvironment() {
      const a = tf.tensor1d([1]);
      const b = tf.tensor1d([1]);
      let sum = null;
      try {
        await tf.time(() => {
          sum = tf.add(a, b);
          return sum.data();
        });
      } finally {
        // Release the probe tensors; they were previously leaked.
        for (const t of [a, b, sum]) {
          if (t != null) {
            t.dispose();
          }
        }
      }
      envDiv.innerHTML += `<br/>${JSON.stringify(tf.env().features, null, 2)} `;
    }

    /**
     * Clears the kernel table and writes its header row.
     * 'individual' kernel timing adds Inputs/Output columns, and additionally
     * a GPUPrograms column when the webgl backend is selected.
     */
    async function setupTable() {
      kernelsTableHead.innerText = '';
      kernelTable.innerHTML = '';
      await tf.nextFrame();

      const headerCells = ['<b>Kernel</b>', '<b>Time(ms)</b>'];
      const showDetails = state.kernelTiming === 'individual';
      if (showDetails) {
        headerCells.push('<b>Inputs</b>', '<b>Output</b>');
      }
      if (showDetails && state.backend === 'webgl') {
        headerCells.push('<b>GPUPrograms</b>');
      }
      appendRow(kernelsTableHead, ...headerCells);

      await tf.nextFrame();
    }

    /**
     * Builds a <tr> with one <td> per cell and appends it to `tbody`.
     * HTMLElement cells are inserted as child nodes; any other value is
     * assigned as the cell's innerHTML.
     */
    function appendRow(tbody, ...cells) {
      const row = document.createElement('tr');
      for (const cell of cells) {
        const column = document.createElement('td');
        const isElement = cell instanceof HTMLElement;
        if (isElement) {
          column.appendChild(cell);
        } else {
          column.innerHTML = cell;
        }
        row.appendChild(column);
      }
      tbody.appendChild(row);
    }

    /**
     * Times the first prediction (which includes any one-off warm-up cost)
     * and records it in the timings table as '1st inference'.
     */
    async function warmUpAndRecordTime() {
      await showMsg('Warming up');
      const t0 = performance.now();

      let output = predict(model);
      if (output instanceof Promise) {
        output = await output;
      }
      // Tensor results are read back (presumably so the timing includes the
      // possibly asynchronous computation) and then released.
      if (output instanceof tf.Tensor) {
        await output.data();
        output.dispose();
      }

      const elapsedMs = performance.now() - t0;
      await showMsg(null);
      appendRow(timeTable, '1st inference', printTime(elapsedMs));
    }

    /**
     * Loads the benchmark's model (reusing `benchmark.model` when already
     * cached), fetches its predict function, and records the elapsed time.
     * Sets the module-level `model` and `predict`.
     *
     * @param {Object} benchmark Entry from `benchmarks` providing `load()`
     *     and `predictFunc()`; its `model` property is used as a cache.
     */
    async function loadAndRecordTime(benchmark) {
      await showMsg('Loading the model');
      const t0 = performance.now();

      if (benchmark.model == null) {
        benchmark.model = await benchmark.load();
      }
      model = benchmark.model;
      predict = benchmark.predictFunc();

      const loadMs = performance.now() - t0;
      await showMsg(null);

      const benchmarkCell = `<b> Benchmark:</b> ${state.benchmark} `;
      const runsCell = `<b> Runs:</b> ${state.numRuns} `;
      appendRow(timeTable, benchmarkCell, runsCell);
      appendRow(timeTable, 'Model load', printTime(loadMs));
    }

    const chartHeight = 150;

    /**
     * Draws `data` as a polyline in the trendline widget `node` and labels
     * the y-axis extremes.
     *
     * @param {Element} node Container holding an <svg> with a <path>, plus
     *     `.yMin` and `.yMax` label elements.
     * @param {number[]} data Values to plot, in x order. A single value is
     *     drawn as a flat line across the chart; an empty array clears it.
     * @param {boolean} forceYMinToZero When true the y axis starts at 0
     *     instead of at the smallest value.
     * @param {function(number): *} yFormatter Formats the y-axis labels.
     */
    function populateTrendline(node, data, forceYMinToZero = false, yFormatter = d => d) {
      const svg = node.querySelector("svg");
      svg.setAttribute("width", chartWidth);
      svg.setAttribute("height", chartHeight);
      const path = node.querySelector("path");

      // Nothing to plot (e.g. numRuns == 0): clear instead of rendering the
      // -Infinity/Infinity that Math.max()/Math.min() return for no values.
      if (data.length === 0) {
        node.querySelector(".yMin").textContent = '';
        node.querySelector(".yMax").textContent = '';
        path.setAttribute("d", '');
        return;
      }

      const yMax = Math.max(...data);
      let yMin = forceYMinToZero ? 0 : Math.min(...data);
      if (yMin === yMax) {
        yMin = 0;
      }

      node.querySelector(".yMin").textContent = yFormatter(yMin);
      node.querySelector(".yMax").textContent = yFormatter(yMax);

      // A lone point would need chartWidth / 0 (and 0 * Infinity = NaN);
      // duplicate it so it spans the chart as a flat line instead.
      const series = data.length === 1 ? [data[0], data[0]] : data;
      const xIncrement = chartWidth / (series.length - 1);
      // All-zero data leaves a zero y range; pin such points to the baseline
      // rather than dividing by zero.
      const yRange = yMax - yMin;
      const toY = d => yRange === 0 ?
          chartHeight :
          chartHeight - ((d - yMin) / yRange) * chartHeight;

      const points = series.map((d, i) => `${i * xIncrement},${toY(d)}`);
      path.setAttribute("d", `M${points.join('L')}`);
    }

    /**
     * Times `state.numRuns` consecutive predictions, plots them in the
     * inference-time trendline, and records the average time, the best time
     * and the tensor-count growth observed during the first run.
     * Also sets the module-level `chartWidth` from the trendline container.
     */
    async function measureAveragePredictTime() {
      await showMsg(`Running predict ${state.numRuns} times`);
      const trendlineNode = document.querySelector("#perf-trendline-container");
      chartWidth = trendlineNode.getBoundingClientRect().width;

      const durations = [];
      const tensorGrowth = [];
      for (let run = 0; run < state.numRuns; run++) {
        const t0 = performance.now();
        const liveBefore = tf.memory().numTensors;

        let output = predict(model);
        if (output instanceof Promise) {
          output = await output;
        }
        if (output instanceof tf.Tensor) {
          await output.data();
          output.dispose();
        }

        durations.push(performance.now() - t0);
        tensorGrowth.push(tf.memory().numTensors - liveBefore);
      }

      // Inference times are drawn against a zero baseline.
      populateTrendline(trendlineNode, durations, true, printTime);

      await showMsg(null);
      let total = 0;
      for (const ms of durations) {
        total += ms;
      }
      const average = total / durations.length;
      const best = Math.min(...durations);
      appendRow(timeTable, `Subsequent average(${state.numRuns} runs)`, printTime(average));
      appendRow(timeTable, 'Best time', printTime(best));
      appendRow(timeTable, 'Leaked tensors', tensorGrowth[0]);
    }

    /**
     * Runs one prediction under tf.profile() and records the reported peak
     * memory plus the wall-clock time of that (second) inference.
     */
    async function profileMemory() {
      await showMsg('Profile memory');
      const t0 = performance.now();

      let output;
      const profileInfo = await tf.profile(() => {
        output = predict(model);
        return output;
      });
      if (output instanceof Promise) {
        output = await output;
      }
      if (output instanceof tf.Tensor) {
        await output.data();
        output.dispose();
      }

      const elapsedMs = performance.now() - t0;
      await showMsg(null);
      appendRow(timeTable, 'Peak memory', printMemory(profileInfo.peakBytes));
      appendRow(timeTable, '2nd inference', printTime(elapsedMs));
    }

    /**
     * Fills the kernel table from captured profiler records.
     * 'individual' mode writes one row per record (name = last scope, with
     * the outer scopes as a tooltip). Otherwise, times are summed per first
     * scope and written in descending order of total time.
     *
     * @param {Array<Object>} kernels Records with `scopes`, `time`, and (for
     *     individual mode) `inputs`, `output`, `gpuProgramsInfo`.
     */
    function showKernelTime(kernels) {
      const tbody = document.querySelector('#kernels tbody');
      const makeLabel = (text, tooltip) => {
        const span = document.createElement('span');
        span.setAttribute('title', tooltip);
        span.textContent = text;
        return span;
      };

      if (state.kernelTiming === 'individual') {
        for (const kernel of kernels) {
          const {scopes} = kernel;
          const label = makeLabel(scopes[scopes.length - 1], scopes.slice(0, -1).join(' --> '));
          appendRow(tbody, label, kernel.time.toFixed(2), kernel.inputs,
            kernel.output, kernel.gpuProgramsInfo);
        }
        return;
      }

      const totals = kernels.reduce((acc, {scopes, time}) => {
        const name = scopes[0];
        acc[name] = (acc[name] == null ? 0 : acc[name]) + time;
        return acc;
      }, {});
      Object.entries(totals)
          .sort(([, timeA], [, timeB]) => timeB - timeA)
          .forEach(([name, total]) => {
            appendRow(tbody, makeLabel(name, name), total.toFixed(2));
          });
    }

    /**
     * Runs one prediction with the DEBUG flag on, capturing the per-kernel
     * profiler lines tfjs prints via console.log, then records the kernel
     * count and fills the kernel table (via showKernelTime).
     * console.log and the DEBUG flag are always restored, even when the
     * prediction throws.
     */
    async function profileKernelTime() {
      await showMsg('Profiling kernels');
      const oldLog = console.log;
      const kernels = [];
      tf.env().set('DEBUG', true);
      console.log = (msg, ...rest) => {
        // Each tab-separated field of a profiler line carries a 2-character
        // prefix (presumably a '%c' style marker) that is stripped here.
        const parts = typeof msg === 'string' ?
            msg.split('\t').map(x => x.slice(2)) :
            [];

        if (parts.length > 2) {
          // heuristic for determining whether we've caught a profiler
          // log statement as opposed to a regular console.log
          // TODO(https://github.com/tensorflow/tfjs/issues/563): return timing information as part of tf.profile
          const scopes = parts[0].trim()
            .split('||')
            .filter(s => s !== 'unnamed scope');
          kernels.push({
            scopes,
            time: Number.parseFloat(parts[1]),
            output: parts[2].trim(),
            inputs: parts[4],
            gpuProgramsInfo: parts[5]
          });
        } else {
          // Forward unrelated logging untouched, including extra arguments.
          oldLog.call(console, msg, ...rest);
        }
      };

      try {
        let res = predict(model);
        if (res instanceof Promise) {
          res = await res;
        }
        if (res instanceof tf.Tensor) {
          await res.data();
          res.dispose();
        }

        await showMsg(null);
        // Presumably gives asynchronously reported kernel timings a chance
        // to reach the console.log shim before it is removed.
        await sleep(10);
      } finally {
        tf.env().set('DEBUG', false);
        // Switch back to the old log.
        console.log = oldLog;
      }

      kernels.sort((a, b) => b.time - a.time);
      appendRow(timeTable, 'Number of kernels', kernels.length);
      // Add an empty row at the end of a benchmark run
      appendRow(timeTable, '', '');
      showKernelTime(kernels);
    }

    /**
     * Runs the full pipeline for the benchmark selected in `state`: table
     * setup, model load, warm-up, memory profile, timed runs and, when
     * possible, kernel profiling.
     * On failure, the progress banner is replaced with the error and the
     * error is rethrown.
     */
    async function runBenchmark() {
      try {
        const benchmark = benchmarks[state.benchmark];
        await setupTable();
        await loadAndRecordTime(benchmark);
        await warmUpAndRecordTime();
        await showMsg('Waiting for GC');
        await sleep(1000);
        await profileMemory();
        await sleep(200);
        await measureAveragePredictTime();
        await sleep(200);
        // On webgl, kernel timings need the query timer extension; skip the
        // kernel profile when queryTimerIsEnabled() reports it missing.
        if (state.backend !== 'webgl' || queryTimerIsEnabled()) {
          await profileKernelTime();
        } else {
          await showMsg('Skipping kernel times since query timer extension is not ' +
            'available. <br/> Use Chrome 70+.');
        }
      } catch (err) {
        // Otherwise the last progress message stays up and the page looks hung.
        await showMsg(`Benchmark failed: ${err instanceof Error ? err.message : err}`);
        throw err;
      }
    }

    /**
     * Activates the default backend, builds the dat.gui settings panel bound
     * to `state`, and prints the library versions and environment flags.
     */
    async function onPageLoad() {
      const gui = new dat.gui.GUI();

      await tf.setBackend(state.backend);

      gui.add(state, 'numRuns');
      gui.add(state, 'benchmark', Object.keys(benchmarks));
      const backendController =
          gui.add(state, 'backend', ['wasm', 'webgl', 'cpu']);
      backendController.onChange(async backend => {
        // tf.setBackend() may resolve to false, or throw (e.g. for an
        // unregistered name); either way the previous backend stays active.
        let switched = false;
        try {
          switched = await tf.setBackend(backend);
        } catch (err) {
          console.error(err);
        }
        if (!switched) {
          // runBenchmark() reads state.backend, so keep it (and the
          // dropdown) in sync with the backend that is actually active.
          state.backend = tf.getBackend();
          backendController.updateDisplay();
          await showMsg(`Could not initialize the ${backend} backend; ` +
            `still using ${state.backend}`);
        }
      });
      gui.add(state, 'kernelTiming', ['aggregate', 'individual']);
      gui.add(state, 'run');

      showVersions();
      await showEnvironment();
    }

    // Set up the backend, the settings panel and the environment info as soon
    // as this script runs.
    onPageLoad();
  </script>
</body>

</html>
